{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Convolutional Neural Networks\n",
    "## Standard LeNet5 with PyTorch\n",
    "### Xavier Bresson, Sept. 2017"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Implementation of original LeNet5 Convolutional Neural Networks:<br>\n",
    "Gradient-based learning applied to document recognition<br>\n",
    "Y LeCun, L Bottou, Y Bengio, P Haffner<br>\n",
    "Proceedings of the IEEE 86 (11), 2278-2324<br>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda available\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch.autograd import Variable\n",
    "import torch.nn.functional as F\n",
    "import torch.nn as nn\n",
    "import pdb #pdb.set_trace()\n",
    "import collections\n",
    "import time\n",
    "import numpy as np\n",
    "\n",
    "import os\n",
    "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    print('cuda available')\n",
    "    dtypeFloat = torch.cuda.FloatTensor\n",
    "    dtypeLong = torch.cuda.LongTensor\n",
    "    torch.cuda.manual_seed(1)\n",
    "else:\n",
    "    print('cuda not available')\n",
    "    dtypeFloat = torch.FloatTensor\n",
    "    dtypeLong = torch.LongTensor\n",
    "    torch.manual_seed(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting datasets/train-images-idx3-ubyte.gz\n",
      "Extracting datasets/train-labels-idx1-ubyte.gz\n",
      "Extracting datasets/t10k-images-idx3-ubyte.gz\n",
      "Extracting datasets/t10k-labels-idx1-ubyte.gz\n",
      "(55000, 784)\n",
      "(55000,)\n",
      "(5000, 784)\n",
      "(5000,)\n",
      "(10000, 784)\n",
      "(10000,)\n"
     ]
    }
   ],
   "source": [
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "mnist = input_data.read_data_sets('datasets', one_hot=False) # load data in folder datasets/\n",
    "\n",
    "train_data = mnist.train.images.astype(np.float32)\n",
    "val_data = mnist.validation.images.astype(np.float32)\n",
    "test_data = mnist.test.images.astype(np.float32)\n",
    "train_labels = mnist.train.labels\n",
    "val_labels = mnist.validation.labels\n",
    "test_labels = mnist.test.labels\n",
    "print(train_data.shape)\n",
    "print(train_labels.shape)\n",
    "print(val_data.shape)\n",
    "print(val_labels.shape)\n",
    "print(test_data.shape)\n",
    "print(test_labels.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ConvNet LeNet5\n",
    "### Layers: CL32-MP4-CL64-MP4-FC512-FC10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true,
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# class definition\n",
    "class ConvNet_LeNet5(nn.Module):\n",
    "    \n",
    "    def __init__(self, net_parameters):\n",
    "        \n",
    "        print('ConvNet: LeNet5\\n')\n",
    "        \n",
    "        super(ConvNet_LeNet5, self).__init__()\n",
    "        \n",
    "        Nx, Ny, CL1_F, CL1_K, CL2_F, CL2_K, FC1_F, FC2_F = net_parameters\n",
    "        FC1Fin = CL2_F*(Nx//4)**2\n",
    "        \n",
    "        # graph CL1\n",
    "        self.conv1 = nn.Conv2d(1, CL1_F, CL1_K, padding=(2, 2))\n",
    "        Fin = CL1_K**2; Fout = CL1_F;\n",
    "        scale = np.sqrt( 2.0/ (Fin+Fout) )\n",
    "        self.conv1.weight.data.uniform_(-scale, scale)\n",
    "        self.conv1.bias.data.fill_(0.0)\n",
    "        \n",
    "        # graph CL2\n",
    "        self.conv2 = nn.Conv2d(CL1_F, CL2_F, CL2_K, padding=(2, 2))\n",
    "        Fin = CL1_F*CL2_K**2; Fout = CL2_F;\n",
    "        scale = np.sqrt( 2.0/ (Fin+Fout) )\n",
    "        self.conv2.weight.data.uniform_(-scale, scale)\n",
    "        self.conv2.bias.data.fill_(0.0)\n",
    "        \n",
    "        # FC1\n",
    "        self.fc1 = nn.Linear(FC1Fin, FC1_F) \n",
    "        Fin = FC1Fin; Fout = FC1_F;\n",
    "        scale = np.sqrt( 2.0/ (Fin+Fout) )\n",
    "        self.fc1.weight.data.uniform_(-scale, scale)\n",
    "        self.fc1.bias.data.fill_(0.0)\n",
    "        self.FC1Fin = FC1Fin\n",
    "        \n",
    "        # FC2\n",
    "        self.fc2 = nn.Linear(FC1_F, FC2_F)\n",
    "        Fin = FC1_F; Fout = FC2_F;\n",
    "        scale = np.sqrt( 2.0/ (Fin+Fout) )\n",
    "        self.fc2.weight.data.uniform_(-scale, scale)\n",
    "        self.fc2.bias.data.fill_(0.0)\n",
    "        \n",
    "        # max pooling\n",
    "        self.pool = nn.MaxPool2d(2, 2)\n",
    "         \n",
    "        \n",
    "    def forward(self, x, d):\n",
    "        \n",
    "        # CL1\n",
    "        x = self.conv1(x)    \n",
    "        x = F.relu(x)\n",
    "        x = self.pool(x)\n",
    "\n",
    "        # CL2\n",
    "        x = self.conv2(x)    \n",
    "        x = F.relu(x)\n",
    "        x = self.pool(x)\n",
    "\n",
    "        # FC1\n",
    "        x = x.permute(0,3,2,1).contiguous() # reshape from pytorch array to tensorflow array\n",
    "        x = x.view(-1, self.FC1Fin)\n",
    "        x = self.fc1(x)\n",
    "        x = F.relu(x)\n",
    "        x  = nn.Dropout(d)(x)\n",
    "        \n",
    "        # FC2\n",
    "        x = self.fc2(x)\n",
    "            \n",
    "        return x\n",
    "        \n",
    "        \n",
    "    def loss(self, y, y_target, l2_regularization):\n",
    "    \n",
    "        loss = nn.CrossEntropyLoss()(y,y_target)\n",
    "\n",
    "        l2_loss = 0.0\n",
    "        for param in self.parameters():\n",
    "            data = param* param\n",
    "            l2_loss += data.sum()\n",
    "           \n",
    "        loss += 0.5* l2_regularization* l2_loss\n",
    "            \n",
    "        return loss\n",
    "    \n",
    "    \n",
    "    def update(self, lr):\n",
    "                \n",
    "        update = torch.optim.SGD( self.parameters(), lr=lr, momentum=0.9 )\n",
    "        \n",
    "        return update\n",
    "        \n",
    "           \n",
    "    def update_learning_rate(self, optimizer, lr):\n",
    "   \n",
    "        for param_group in optimizer.param_groups:\n",
    "            param_group['lr'] = lr\n",
    "\n",
    "        return optimizer\n",
    "\n",
    "    \n",
    "    def evaluation(self, y_predicted, test_l):\n",
    "    \n",
    "        _, class_predicted = torch.max(y_predicted.data, 1)\n",
    "        return 100.0* (class_predicted == test_l).sum()/ y_predicted.size(0)\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "No existing network to delete\n",
      "\n",
      "ConvNet: LeNet5\n",
      "\n",
      "ConvNet_LeNet5 (\n",
      "  (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
      "  (conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
      "  (fc1): Linear (3136 -> 512)\n",
      "  (fc2): Linear (512 -> 10)\n",
      "  (pool): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))\n",
      ")\n",
      "num_epochs= 20 , train_size= 55000 , nb_iter= 11000\n",
      "epoch= 1, i=  100, loss(batch)= 0.2640, accuray(batch)= 94.00\n",
      "epoch= 1, i=  200, loss(batch)= 0.2147, accuray(batch)= 97.00\n",
      "epoch= 1, i=  300, loss(batch)= 0.2581, accuray(batch)= 94.00\n",
      "epoch= 1, i=  400, loss(batch)= 0.1280, accuray(batch)= 100.00\n",
      "epoch= 1, i=  500, loss(batch)= 0.2204, accuray(batch)= 96.00\n",
      "epoch= 1, loss(train)= 0.305, accuracy(train)= 93.693, time= 7.688, lr= 0.05000\n",
      "  accuracy(test) = 98.660 %, time= 0.331\n",
      "epoch= 2, i=  100, loss(batch)= 0.1557, accuray(batch)= 98.00\n",
      "epoch= 2, i=  200, loss(batch)= 0.1801, accuray(batch)= 97.00\n",
      "epoch= 2, i=  300, loss(batch)= 0.1663, accuray(batch)= 97.00\n",
      "epoch= 2, i=  400, loss(batch)= 0.1095, accuray(batch)= 99.00\n",
      "epoch= 2, i=  500, loss(batch)= 0.1311, accuray(batch)= 99.00\n",
      "epoch= 2, loss(train)= 0.153, accuracy(train)= 98.324, time= 6.966, lr= 0.04750\n",
      "  accuracy(test) = 98.670 %, time= 0.331\n",
      "epoch= 3, i=  100, loss(batch)= 0.0953, accuray(batch)= 100.00\n",
      "epoch= 3, i=  200, loss(batch)= 0.0909, accuray(batch)= 100.00\n",
      "epoch= 3, i=  300, loss(batch)= 0.1151, accuray(batch)= 98.00\n",
      "epoch= 3, i=  400, loss(batch)= 0.0906, accuray(batch)= 100.00\n",
      "epoch= 3, i=  500, loss(batch)= 0.1811, accuray(batch)= 96.00\n",
      "epoch= 3, loss(train)= 0.128, accuracy(train)= 98.700, time= 6.985, lr= 0.04512\n",
      "  accuracy(test) = 99.020 %, time= 0.331\n",
      "epoch= 4, i=  100, loss(batch)= 0.0937, accuray(batch)= 100.00\n",
      "epoch= 4, i=  200, loss(batch)= 0.1036, accuray(batch)= 100.00\n",
      "epoch= 4, i=  300, loss(batch)= 0.1260, accuray(batch)= 99.00\n",
      "epoch= 4, i=  400, loss(batch)= 0.1336, accuray(batch)= 96.00\n",
      "epoch= 4, i=  500, loss(batch)= 0.1960, accuray(batch)= 98.00\n",
      "epoch= 4, loss(train)= 0.110, accuracy(train)= 99.016, time= 6.964, lr= 0.04287\n",
      "  accuracy(test) = 99.140 %, time= 0.331\n",
      "epoch= 5, i=  100, loss(batch)= 0.1039, accuray(batch)= 98.00\n",
      "epoch= 5, i=  200, loss(batch)= 0.0830, accuray(batch)= 100.00\n",
      "epoch= 5, i=  300, loss(batch)= 0.1148, accuray(batch)= 99.00\n",
      "epoch= 5, i=  400, loss(batch)= 0.0776, accuray(batch)= 100.00\n",
      "epoch= 5, i=  500, loss(batch)= 0.0837, accuray(batch)= 99.00\n",
      "epoch= 5, loss(train)= 0.098, accuracy(train)= 99.133, time= 6.945, lr= 0.04073\n",
      "  accuracy(test) = 99.140 %, time= 0.332\n",
      "epoch= 6, i=  100, loss(batch)= 0.0686, accuray(batch)= 100.00\n",
      "epoch= 6, i=  200, loss(batch)= 0.0952, accuray(batch)= 99.00\n",
      "epoch= 6, i=  300, loss(batch)= 0.0655, accuray(batch)= 100.00\n",
      "epoch= 6, i=  400, loss(batch)= 0.0852, accuray(batch)= 99.00\n",
      "epoch= 6, i=  500, loss(batch)= 0.0977, accuray(batch)= 98.00\n",
      "epoch= 6, loss(train)= 0.089, accuracy(train)= 99.222, time= 6.999, lr= 0.03869\n",
      "  accuracy(test) = 99.250 %, time= 0.332\n",
      "epoch= 7, i=  100, loss(batch)= 0.0697, accuray(batch)= 99.00\n",
      "epoch= 7, i=  200, loss(batch)= 0.0779, accuray(batch)= 100.00\n",
      "epoch= 7, i=  300, loss(batch)= 0.0649, accuray(batch)= 100.00\n",
      "epoch= 7, i=  400, loss(batch)= 0.0802, accuray(batch)= 99.00\n",
      "epoch= 7, i=  500, loss(batch)= 0.0695, accuray(batch)= 100.00\n",
      "epoch= 7, loss(train)= 0.081, accuracy(train)= 99.282, time= 6.983, lr= 0.03675\n",
      "  accuracy(test) = 99.250 %, time= 0.332\n",
      "epoch= 8, i=  100, loss(batch)= 0.0647, accuray(batch)= 99.00\n",
      "epoch= 8, i=  200, loss(batch)= 0.0621, accuray(batch)= 100.00\n",
      "epoch= 8, i=  300, loss(batch)= 0.0597, accuray(batch)= 100.00\n",
      "epoch= 8, i=  400, loss(batch)= 0.0584, accuray(batch)= 100.00\n",
      "epoch= 8, i=  500, loss(batch)= 0.0705, accuray(batch)= 99.00\n",
      "epoch= 8, loss(train)= 0.075, accuracy(train)= 99.389, time= 6.964, lr= 0.03492\n",
      "  accuracy(test) = 99.270 %, time= 0.332\n",
      "epoch= 9, i=  100, loss(batch)= 0.0586, accuray(batch)= 100.00\n",
      "epoch= 9, i=  200, loss(batch)= 0.0761, accuray(batch)= 99.00\n",
      "epoch= 9, i=  300, loss(batch)= 0.0554, accuray(batch)= 100.00\n",
      "epoch= 9, i=  400, loss(batch)= 0.0877, accuray(batch)= 98.00\n",
      "epoch= 9, i=  500, loss(batch)= 0.0639, accuray(batch)= 100.00\n",
      "epoch= 9, loss(train)= 0.070, accuracy(train)= 99.455, time= 6.962, lr= 0.03317\n",
      "  accuracy(test) = 99.320 %, time= 0.332\n",
      "epoch= 10, i=  100, loss(batch)= 0.0797, accuray(batch)= 99.00\n",
      "epoch= 10, i=  200, loss(batch)= 0.0581, accuray(batch)= 100.00\n",
      "epoch= 10, i=  300, loss(batch)= 0.0611, accuray(batch)= 99.00\n",
      "epoch= 10, i=  400, loss(batch)= 0.0520, accuray(batch)= 100.00\n",
      "epoch= 10, i=  500, loss(batch)= 0.0996, accuray(batch)= 98.00\n",
      "epoch= 10, loss(train)= 0.065, accuracy(train)= 99.482, time= 6.952, lr= 0.03151\n",
      "  accuracy(test) = 99.150 %, time= 0.332\n",
      "epoch= 11, i=  100, loss(batch)= 0.0489, accuray(batch)= 100.00\n",
      "epoch= 11, i=  200, loss(batch)= 0.0475, accuray(batch)= 100.00\n",
      "epoch= 11, i=  300, loss(batch)= 0.0547, accuray(batch)= 100.00\n",
      "epoch= 11, i=  400, loss(batch)= 0.0627, accuray(batch)= 99.00\n",
      "epoch= 11, i=  500, loss(batch)= 0.0535, accuray(batch)= 100.00\n",
      "epoch= 11, loss(train)= 0.062, accuracy(train)= 99.518, time= 7.004, lr= 0.02994\n",
      "  accuracy(test) = 99.400 %, time= 0.332\n",
      "epoch= 12, i=  100, loss(batch)= 0.0476, accuray(batch)= 100.00\n",
      "epoch= 12, i=  200, loss(batch)= 0.0720, accuray(batch)= 98.00\n",
      "epoch= 12, i=  300, loss(batch)= 0.0556, accuray(batch)= 99.00\n",
      "epoch= 12, i=  400, loss(batch)= 0.0575, accuray(batch)= 100.00\n",
      "epoch= 12, i=  500, loss(batch)= 0.0795, accuray(batch)= 99.00\n",
      "epoch= 12, loss(train)= 0.059, accuracy(train)= 99.604, time= 6.957, lr= 0.02844\n",
      "  accuracy(test) = 99.380 %, time= 0.333\n",
      "epoch= 13, i=  100, loss(batch)= 0.0440, accuray(batch)= 100.00\n",
      "epoch= 13, i=  200, loss(batch)= 0.0507, accuray(batch)= 100.00\n",
      "epoch= 13, i=  300, loss(batch)= 0.0643, accuray(batch)= 99.00\n",
      "epoch= 13, i=  400, loss(batch)= 0.0956, accuray(batch)= 98.00\n",
      "epoch= 13, i=  500, loss(batch)= 0.0489, accuray(batch)= 100.00\n",
      "epoch= 13, loss(train)= 0.056, accuracy(train)= 99.602, time= 6.963, lr= 0.02702\n",
      "  accuracy(test) = 99.360 %, time= 0.333\n",
      "epoch= 14, i=  100, loss(batch)= 0.0465, accuray(batch)= 100.00\n",
      "epoch= 14, i=  200, loss(batch)= 0.0441, accuray(batch)= 100.00\n",
      "epoch= 14, i=  300, loss(batch)= 0.0504, accuray(batch)= 100.00\n",
      "epoch= 14, i=  400, loss(batch)= 0.0433, accuray(batch)= 100.00\n",
      "epoch= 14, i=  500, loss(batch)= 0.0581, accuray(batch)= 100.00\n",
      "epoch= 14, loss(train)= 0.054, accuracy(train)= 99.605, time= 6.990, lr= 0.02567\n",
      "  accuracy(test) = 99.330 %, time= 0.333\n",
      "epoch= 15, i=  100, loss(batch)= 0.0475, accuray(batch)= 100.00\n",
      "epoch= 15, i=  200, loss(batch)= 0.0428, accuray(batch)= 100.00\n",
      "epoch= 15, i=  300, loss(batch)= 0.0437, accuray(batch)= 100.00\n",
      "epoch= 15, i=  400, loss(batch)= 0.0945, accuray(batch)= 98.00\n",
      "epoch= 15, i=  500, loss(batch)= 0.0798, accuray(batch)= 99.00\n",
      "epoch= 15, loss(train)= 0.053, accuracy(train)= 99.613, time= 6.993, lr= 0.02438\n",
      "  accuracy(test) = 99.320 %, time= 0.333\n",
      "epoch= 16, i=  100, loss(batch)= 0.0403, accuray(batch)= 100.00\n",
      "epoch= 16, i=  200, loss(batch)= 0.0571, accuray(batch)= 99.00\n",
      "epoch= 16, i=  300, loss(batch)= 0.0444, accuray(batch)= 100.00\n",
      "epoch= 16, i=  400, loss(batch)= 0.0985, accuray(batch)= 99.00\n",
      "epoch= 16, i=  500, loss(batch)= 0.0457, accuray(batch)= 100.00\n",
      "epoch= 16, loss(train)= 0.051, accuracy(train)= 99.669, time= 6.950, lr= 0.02316\n",
      "  accuracy(test) = 99.300 %, time= 0.333\n",
      "epoch= 17, i=  100, loss(batch)= 0.1177, accuray(batch)= 97.00\n",
      "epoch= 17, i=  200, loss(batch)= 0.0484, accuray(batch)= 99.00\n",
      "epoch= 17, i=  300, loss(batch)= 0.0496, accuray(batch)= 99.00\n",
      "epoch= 17, i=  400, loss(batch)= 0.0426, accuray(batch)= 100.00\n",
      "epoch= 17, i=  500, loss(batch)= 0.0460, accuray(batch)= 100.00\n",
      "epoch= 17, loss(train)= 0.049, accuracy(train)= 99.695, time= 6.899, lr= 0.02201\n",
      "  accuracy(test) = 99.380 %, time= 0.333\n",
      "epoch= 18, i=  100, loss(batch)= 0.0480, accuray(batch)= 99.00\n",
      "epoch= 18, i=  200, loss(batch)= 0.0446, accuray(batch)= 100.00\n",
      "epoch= 18, i=  300, loss(batch)= 0.0415, accuray(batch)= 100.00\n",
      "epoch= 18, i=  400, loss(batch)= 0.0408, accuray(batch)= 100.00\n",
      "epoch= 18, i=  500, loss(batch)= 0.0376, accuray(batch)= 100.00\n",
      "epoch= 18, loss(train)= 0.048, accuracy(train)= 99.704, time= 6.905, lr= 0.02091\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  accuracy(test) = 99.350 %, time= 0.333\n",
      "epoch= 19, i=  100, loss(batch)= 0.0413, accuray(batch)= 100.00\n",
      "epoch= 19, i=  200, loss(batch)= 0.0397, accuray(batch)= 100.00\n",
      "epoch= 19, i=  300, loss(batch)= 0.0434, accuray(batch)= 100.00\n",
      "epoch= 19, i=  400, loss(batch)= 0.0434, accuray(batch)= 100.00\n",
      "epoch= 19, i=  500, loss(batch)= 0.0458, accuray(batch)= 100.00\n",
      "epoch= 19, loss(train)= 0.046, accuracy(train)= 99.725, time= 6.908, lr= 0.01986\n",
      "  accuracy(test) = 99.390 %, time= 0.334\n",
      "epoch= 20, i=  100, loss(batch)= 0.0365, accuray(batch)= 100.00\n",
      "epoch= 20, i=  200, loss(batch)= 0.0454, accuray(batch)= 100.00\n",
      "epoch= 20, i=  300, loss(batch)= 0.0360, accuray(batch)= 100.00\n",
      "epoch= 20, i=  400, loss(batch)= 0.0598, accuray(batch)= 98.00\n",
      "epoch= 20, i=  500, loss(batch)= 0.0364, accuray(batch)= 100.00\n",
      "epoch= 20, loss(train)= 0.045, accuracy(train)= 99.736, time= 6.908, lr= 0.01887\n",
      "  accuracy(test) = 99.310 %, time= 0.333\n"
     ]
    }
   ],
   "source": [
    "# Delete existing network if exists\n",
    "try:\n",
    "    del net\n",
    "    print('Delete existing network\\n')\n",
    "except NameError:\n",
    "    print('No existing network to delete\\n')\n",
    "\n",
    "\n",
    "\n",
    "# network parameters\n",
    "Nx = Ny = 28\n",
    "CL1_F = 32\n",
    "CL1_K = 5\n",
    "CL2_F = 64\n",
    "CL2_K = 5\n",
    "FC1_F = 512\n",
    "FC2_F = 10\n",
    "net_parameters = [Nx, Ny, CL1_F, CL1_K, CL2_F, CL2_K, FC1_F, FC2_F]\n",
    "\n",
    "\n",
    "# instantiate the object net of the class \n",
    "net = ConvNet_LeNet5(net_parameters)\n",
    "if torch.cuda.is_available():\n",
    "    net.cuda()\n",
    "print(net)\n",
    "\n",
    "\n",
    "# Weights\n",
    "L = list(net.parameters())\n",
    "\n",
    "\n",
    "# learning parameters\n",
    "learning_rate = 0.05\n",
    "dropout_value = 0.5\n",
    "l2_regularization = 5e-4 \n",
    "batch_size = 100\n",
    "num_epochs = 20\n",
    "train_size = train_data.shape[0]\n",
    "nb_iter = int(num_epochs * train_size) // batch_size\n",
    "print('num_epochs=',num_epochs,', train_size=',train_size,', nb_iter=',nb_iter)\n",
    "\n",
    "\n",
    "# Optimizer\n",
    "global_lr = learning_rate\n",
    "global_step = 0\n",
    "decay = 0.95\n",
    "decay_steps = train_size\n",
    "lr = learning_rate\n",
    "optimizer = net.update(lr) \n",
    "\n",
    "\n",
    "# loop over epochs\n",
    "indices = collections.deque()\n",
    "for epoch in range(num_epochs):  # loop over the dataset multiple times\n",
    "\n",
    "    # reshuffle \n",
    "    indices.extend(np.random.permutation(train_size)) # rand permutation\n",
    "    \n",
    "    # reset time\n",
    "    t_start = time.time()\n",
    "    \n",
    "    # extract batches\n",
    "    running_loss = 0.0\n",
    "    running_accuray = 0\n",
    "    running_total = 0\n",
    "    while len(indices) >= batch_size:\n",
    "        \n",
    "        # extract batches\n",
    "        batch_idx = [indices.popleft() for i in range(batch_size)]\n",
    "        train_x, train_y = train_data[batch_idx,:].T, train_labels[batch_idx].T\n",
    "        train_x = np.reshape(train_x,[28,28,batch_size])[:,:,:,None]\n",
    "        train_x = np.transpose(train_x,[2,3,1,0]) # reshape from pytorch array to tensorflow array\n",
    "        train_x = Variable( torch.FloatTensor(train_x).type(dtypeFloat) , requires_grad=False) \n",
    "        train_y = train_y.astype(np.int64)\n",
    "        train_y = torch.LongTensor(train_y).type(dtypeLong)\n",
    "        train_y = Variable( train_y , requires_grad=False) \n",
    "            \n",
    "        # Forward \n",
    "        y = net.forward(train_x, dropout_value)\n",
    "        loss = net.loss(y,train_y,l2_regularization) \n",
    "        loss_train = loss.data[0]\n",
    "        \n",
    "        # Accuracy\n",
    "        acc_train = net.evaluation(y,train_y.data)\n",
    "        \n",
    "        # backward\n",
    "        loss.backward()\n",
    "        \n",
    "        # Update \n",
    "        global_step += batch_size # to update learning rate\n",
    "        optimizer.step()\n",
    "        optimizer.zero_grad()\n",
    "        \n",
    "        # loss, accuracy\n",
    "        running_loss += loss_train\n",
    "        running_accuray += acc_train\n",
    "        running_total += 1\n",
    "        \n",
    "        # print        \n",
    "        if not running_total%100: # print every x mini-batches\n",
    "            print('epoch= %d, i= %4d, loss(batch)= %.4f, accuray(batch)= %.2f' % (epoch+1, running_total, loss_train, acc_train))\n",
    "          \n",
    "       \n",
    "    # print \n",
    "    t_stop = time.time() - t_start\n",
    "    print('epoch= %d, loss(train)= %.3f, accuracy(train)= %.3f, time= %.3f, lr= %.5f' % \n",
    "          (epoch+1, running_loss/running_total, running_accuray/running_total, t_stop, lr))\n",
    " \n",
    "\n",
    "    # update learning rate \n",
    "    lr = global_lr * pow( decay , float(global_step// decay_steps) )\n",
    "    optimizer = net.update_learning_rate(optimizer, lr)\n",
    "    \n",
    "    \n",
    "    # Test set\n",
    "    running_accuray_test = 0\n",
    "    running_total_test = 0\n",
    "    indices_test = collections.deque()\n",
    "    indices_test.extend(range(test_data.shape[0]))\n",
    "    t_start_test = time.time()\n",
    "    while len(indices_test) >= batch_size:\n",
    "        batch_idx_test = [indices_test.popleft() for i in range(batch_size)]\n",
    "        test_x, test_y = test_data[batch_idx_test,:].T, test_labels[batch_idx_test].T\n",
    "        test_x = np.reshape(test_x,[28,28,batch_size])[:,:,:,None]\n",
    "        test_x = np.transpose(test_x,[2,3,1,0]) # reshape from pytorch array to tensorflow array\n",
    "        test_x = Variable( torch.FloatTensor(test_x).type(dtypeFloat) , requires_grad=False) \n",
    "        y = net.forward(test_x, 0.0) \n",
    "        test_y = test_y.astype(np.int64)\n",
    "        test_y = torch.LongTensor(test_y).type(dtypeLong)\n",
    "        test_y = Variable( test_y , requires_grad=False) \n",
    "        acc_test = net.evaluation(y,test_y.data)\n",
    "        running_accuray_test += acc_test\n",
    "        running_total_test += 1\n",
    "    t_stop_test = time.time() - t_start_test\n",
    "    print('  accuracy(test) = %.3f %%, time= %.3f' % (running_accuray_test / running_total_test, t_stop_test))  \n",
    "\n",
    "    \n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
